from abc import ABCMeta, abstractmethod
import torch
import torch.distributed as dist
from xtuner._lite import get_device


# Device returned by xtuner's get_device(), looked up once when this module is
# imported. The recorders below move their stacked statistics onto it.
DEVICE = get_device()

class BaseRecorder(metaclass=ABCMeta):
    """Base class for recorders that observe a model through forward hooks.

    A recorder is configured with ``source``, a module-name suffix.
    ``prepare_model`` registers ``forward_hook`` on every sub-module whose
    qualified name ends with that suffix. Subclasses must implement
    ``forward_hook``; ``ABCMeta`` is used so that the ``@abstractmethod``
    marker is actually enforced at instantiation time.
    """

    def __init__(self, source):
        """Store the module-name suffix and start with recording enabled.

        Args:
            source: Suffix matched against the names yielded by
                ``model.named_modules()``.
        """
        super().__init__()
        self.source = source
        self._recording = True

    @property
    def recording(self):
        """Whether hooks should currently record (``True`` after init)."""
        return self._recording

    def disable(self):
        """Turn recording off; already-registered hooks stay attached."""
        self._recording = False

    def enable(self):
        """Turn recording back on."""
        self._recording = True

    def prepare_model(self, model) -> None:
        """Attach ``forward_hook`` to every module matching ``source``.

        Each matching module also gets a ``name`` attribute holding its
        qualified name, so the hook can identify which module fired.

        Args:
            model: Object exposing ``named_modules()`` that yields
                ``(name, module)`` pairs, where each module supports
                ``register_forward_hook``.

        Raises:
            AssertionError: If no module name ends with ``source``.
        """
        found = False

        for name, module in model.named_modules():
            if name.endswith(self.source):
                module.register_forward_hook(self.forward_hook)
                module.name = name
                found = True

        # Raised explicitly (rather than via ``assert``) so the check is not
        # stripped when Python runs with -O; the exception type is unchanged.
        if not found:
            raise AssertionError(f'"{self.source}" is not in the model.')

    @abstractmethod
    def forward_hook(self, module, inputs, outputs) -> None:
        """Hook called after a matched module's forward pass.

        Args:
            module: The module the hook is registered on.
            inputs: Inputs passed to the module's forward call.
            outputs: Value returned by the module's forward call.
        """


class GateRecorder(BaseRecorder):
    """Record max / min / mean of the input passed to each hooked module.

    With the default ``source`` the hooks land on modules whose name ends
    with ``.gate``. One scalar per statistic is appended for every hooked
    forward call made while recording is enabled.
    """

    def __init__(self, source='.gate'):
        super().__init__(source=source)
        # Start from empty per-call buffers.
        self.reset()

    def reset(self):
        """Discard everything recorded so far."""
        self.logit_before_gate_max = []
        self.logit_before_gate_min = []
        self.logit_before_gate_mean = []

    def forward_hook(self, module, inputs, outputs) -> None:
        """Append max, min and mean of the module input to the buffers."""
        if not self.recording:
            return

        with torch.no_grad():
            # Use the first element when the inputs arrive as a tuple.
            gate_input = inputs[0] if isinstance(inputs, tuple) else inputs
            self.logit_before_gate_max.append(gate_input.max())
            self.logit_before_gate_min.append(gate_input.min())
            self.logit_before_gate_mean.append(gate_input.mean())

    def reduce_after_iter(self):
        """Stack the recorded scalars, clear the buffers and return them.

        Returns:
            A ``(max, min, mean)`` tuple of tensors on ``DEVICE``, each with
            one entry per recorded forward call, in recording order.
        """
        buffers = (
            self.logit_before_gate_max,
            self.logit_before_gate_min,
            self.logit_before_gate_mean,
        )
        stacked_max, stacked_min, stacked_mean = (
            torch.stack(buf, dim=0).to(DEVICE) for buf in buffers
        )

        # Clear what was collected during the iteration that just finished.
        self.reset()

        # NOTE: the cross-rank reduction (dist.all_reduce with MAX / MIN / AVG)
        # is currently disabled, so the values are local to this process.
        return stacked_max, stacked_min, stacked_mean


class ExpertActivationRecorder(BaseRecorder):
    """Record which experts received at least one token per forward call.

    With the default ``source`` the hooks land on modules whose name ends
    with ``.experts``. The second positional input of the hooked module is
    read as the per-expert token count.
    """

    def __init__(self, source='.experts'):
        super().__init__(source=source)
        # Start from an empty per-call buffer.
        self.reset()

    def reset(self):
        """Discard everything recorded so far."""
        self.expert_activation = []

    def forward_hook(self, module, inputs, outputs) -> None:
        """Append a 0/1 integer mask of experts with a positive token count."""
        if not self.recording:
            return

        with torch.no_grad():
            assert isinstance(inputs, tuple)
            tokens_per_expert = inputs[1]
            is_active = tokens_per_expert > 0
            self.expert_activation.append(is_active.int())

    def reduce_after_iter(self):
        """Summarise the recorded masks, clear the buffer and return lists.

        Returns:
            A ``(min, max, mean)`` tuple of Python lists, each reduced over
            dim 0 of the stacked masks (the expert axis, per the original
            author's "(ne, layers)" note).
        """
        # Stacked along dim 1: (num_experts, num_recorded_calls) when every
        # recorded mask is 1-D over experts -- TODO confirm with the caller.
        activation = torch.stack(self.expert_activation, dim=1).to(DEVICE)

        # Clear what was collected during the iteration that just finished.
        self.reset()

        # NOTE: the cross-rank reduction (dist.all_reduce with SUM) is
        # currently disabled, so the masks are local to this process.

        # Each result has one entry per recorded call: (num_layers, ).
        lowest, _ = activation.min(dim=0)
        highest, _ = activation.max(dim=0)
        average = activation.float().mean(dim=0)

        expert_activation_min = lowest.cpu().tolist()
        expert_activation_max = highest.cpu().tolist()
        expert_activation_mean = average.cpu().tolist()

        return expert_activation_min, expert_activation_max, expert_activation_mean
